# Lint as: python3
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A binary to generate random options data.

Note that Google Cloud SDK, storage client library and tf_quant_finance should
already be installed (see `requirements.txt`). In case the generated data needs
to be written to GCS, authentication steps must be performed prior to running
this script as described at:
https://cloud.google.com/storage/docs/reference/libraries

This script is meant to be run from inside a docker container
(see the Dockerfile). If running outside a container, the PYTHONPATH must be
modified to include the parent folder so the modules in the `common` folder are
visible. For example, `PYTHONPATH=$PYTHONPATH:/my/path/option_pricing_basic/`.
"""

import datetime
from os import path
import posixpath
import tempfile
from typing import Tuple, List

from absl import app
from absl import flags
from absl import logging
from common import datatypes
import dataclasses
import numpy as np
import tf_quant_finance as tff

from google.cloud import storage

# Handle through which the flag values defined below are read in `main`.
FLAGS = flags.FLAGS

# Destination of the generated files (local directory or GCS location).
flags.DEFINE_string(
    name='output_path',
    default='',
    help=('Location where the output data is stored. If the path begins with '
          'gs:// it is assumed to be GCS. If a GCS path is specified, '
          'temporary local files will be created in a platform specific '
          'temporary directory.'))

# Size of the generated data set.
flags.DEFINE_integer(
    name='num_underliers',
    default=1000,
    help='Number of different underliers for the options.')
flags.DEFINE_integer(
    name='options_per_file',
    default=1000000,
    help='Number of options in each file.')
flags.DEFINE_integer(
    name='num_files',
    default=50,
    help='Number of files to generate.')

# Names of the generated files.
flags.DEFINE_string(
    name='output_file_prefix',
    default='portfolio',
    help=('Prefix for generating portfolio file names. The filename will be '
          'generated by adding a number to the prefix '
          '(e.g. portfolio_1-of-20.tfrecords).'))
flags.DEFINE_string(
    name='market_data_file_name',
    default='market_data.tfrecords',
    help=('Name of the market data file. It will be stored in the same '
          'location as the portfolio file.'))

# Reproducibility. The default of None is passed to `np.random.seed` as is.
flags.DEFINE_integer(
    name='random_seed',
    default=None,
    help='An optional random seed to control the data generation.')


def is_gcs_path(gcs_path: str) -> bool:
  """Tells whether a path refers to GCS.

  Args:
    gcs_path: The path to inspect.

  Returns:
    True if `gcs_path` begins with the `gs://` prefix, False otherwise. The
    comparison is case sensitive.
  """
  gcs_prefix = 'gs://'
  return gcs_path.startswith(gcs_prefix)


def generate_market_data(num_underliers: int) -> datatypes.OptionMarketData:
  """Generates random market data for a set of underliers.

  The values are drawn from NumPy's global random state, so seeding it (see
  `np.random.seed`) makes the output reproducible.

  Args:
    num_underliers: Number of underliers to generate market data for.

  Returns:
    An `OptionMarketData` with underlier ids `0, ..., num_underliers - 1` and,
    for each underlier, a spot between 100.0 and 1000.0, a volatility between
    3% and 60% and a rate between 0% and 15%.
  """
  # The draws are made in the order spot, volatility, rate. Keep this order:
  # changing it changes the data produced for a given random seed.
  spot_draws = np.random.rand(num_underliers)
  volatility_draws = np.random.rand(num_underliers)
  rate_draws = np.random.rand(num_underliers)
  return datatypes.OptionMarketData(
      underlier_id=np.arange(num_underliers),
      spot=spot_draws * 900 + 100,
      volatility=volatility_draws * 0.57 + 0.03,
      rate=rate_draws * 0.15)


def generate_portfolio(num_instruments: int,
                       market_data: datatypes.OptionMarketData,
                       start_instrument_id: int = 0) -> datatypes.OptionBatch:
  """Generates a random portfolio of options on the supplied underliers.

  The values are drawn from NumPy's global random state, so seeding it (see
  `np.random.seed`) makes the output reproducible up to the current date.

  Args:
    num_instruments: Number of options to generate.
    market_data: Market data supplying the underlier ids to choose from and
      their spots. Underlier ids are used to index into `market_data.spot`.
    start_instrument_id: Trade id assigned to the first option. The following
      options get consecutive ids.

  Returns:
    An `OptionBatch` with `num_instruments` randomly generated options.
  """
  # datetime.date.toordinal uses the same ordinals as tff's dates module.
  today_ordinal = datetime.date.today().toordinal()

  # The draws below are made in a fixed order (underliers, call/put flags,
  # days to expiry, strike multipliers). Keep this order: changing it changes
  # the data produced for a given random seed.
  chosen_underliers = np.random.choice(
      market_data.underlier_id, size=num_instruments)
  option_type_flags = np.random.rand(num_instruments) > 0.5
  # Expiries fall at least one day and less than ten (365 day) years ahead.
  days_to_expiry = np.random.randint(
      1, high=365 * 10, size=num_instruments)
  # Choose strikes to be within +/- 40% of the spot.
  spot_multipliers = np.random.rand(num_instruments) * 0.8 + 0.6

  return datatypes.OptionBatch(
      strike=market_data.spot[chosen_underliers] * spot_multipliers,
      call_put_flag=option_type_flags,
      expiry_date=today_ordinal + days_to_expiry,
      trade_id=np.arange(num_instruments) + start_instrument_id,
      underlier_id=chosen_underliers)


def split_bucket_path(gcs_path: str) -> Tuple[str, str]:
  """Splits a GCS path into the bucket name and the object path.

  Args:
    gcs_path: A path of the form `gs://<bucket>[/<object path>]`. The `gs://`
      prefix is optional.

  Returns:
    A tuple of the bucket name and the object path within the bucket. The
    object path is empty if `gcs_path` names only a bucket. It always uses
    `/` as the separator, a trailing `/` is kept and empty components (e.g.
    from a doubled `/`) are dropped.
  """
  prefix = 'gs://'
  # Strip only the leading scheme. `str.replace` would also delete any later
  # occurrence of 'gs://' inside the object path.
  if gcs_path.startswith(prefix):
    gcs_path = gcs_path[len(prefix):]
  bucket_name, *object_pieces = gcs_path.split('/')
  # GCS object names always use '/', so join with posixpath rather than the
  # platform specific os.path (which would use '\\' on Windows).
  remainder = posixpath.join(*object_pieces) if object_pieces else ''
  return bucket_name, remainder


def upload_to_gcs(gcs_path: str, local_files: List[str]) -> None:
  """Uploads local files to a GCS location.

  Each file is uploaded to the bucket named in `gcs_path`, under the object
  path given in `gcs_path` and with the base name of the local file.

  Args:
    gcs_path: Destination of the form `gs://<bucket>[/<object path>]`.
    local_files: Paths of the local files to upload.
  """
  client = storage.Client()
  bucket_name, blob_prefix = split_bucket_path(gcs_path)
  bucket = client.bucket(bucket_name)
  for local_path in local_files:
    # The local path is split with os.path since it follows platform rules.
    file_name = path.basename(local_path)
    # GCS object names always use '/', so join with posixpath rather than the
    # platform specific os.path (which would use '\\' on Windows).
    full_blob_name = posixpath.join(blob_prefix, file_name)
    target_blob = bucket.blob(full_blob_name)
    target_blob.upload_from_filename(local_path)
    logging.info('Uploaded %s to gs://%s/%s.',
                 local_path, bucket_name, full_blob_name)

def write_local(local_file_path: str, data_class: object) -> None:
  """Writes the fields of a dataclass object to a local file.

  The object is converted to a dictionary with `dataclasses.asdict` and the
  dictionary is written with `tff.experimental.io.ArrayDictWriter`.

  Args:
    local_file_path: Path of the file to write.
    data_class: The dataclass object to write.

  Raises:
    ValueError: If `data_class` is not a dataclass.
  """
  if not dataclasses.is_dataclass(data_class):
    raise ValueError('Object to write must be a dataclass')
  # Convert before opening the writer so that a failed conversion does not
  # leave a file behind.
  field_values = dataclasses.asdict(data_class)
  with tff.experimental.io.ArrayDictWriter(local_file_path) as record_writer:
    record_writer.write(field_values)


def main(argv):
  """Generates synthetic pricing data.

  Seeds NumPy's global random state with `--random_seed`, writes one market
  data file and `--num_files` portfolio files and, if `--output_path` is a GCS
  path, uploads all of them to that location.

  Args:
    argv: Command line arguments. Unused.
  """
  del argv  # Unused.
  np.random.seed(seed=FLAGS.random_seed)

  destination = FLAGS.output_path
  # Data bound for GCS is first written to the platform's temporary directory.
  staging_directory = (
      tempfile.gettempdir() if is_gcs_path(destination) else destination)

  # Generate and write market data.
  market_data_file_path = path.join(staging_directory,
                                    FLAGS.market_data_file_name)
  market_data = generate_market_data(FLAGS.num_underliers)
  write_local(market_data_file_path, market_data)
  logging.info('Wrote market data to %s', market_data_file_path)

  # Generate the portfolio conditioned on the market data, one file at a time.
  shard_count = FLAGS.num_files
  shard_prefix = FLAGS.output_file_prefix
  shard_size = FLAGS.options_per_file
  shard_paths = []
  for shard_index in range(shard_count):
    # File names are numbered from 1 (e.g. portfolio_1-of-20.tfrecords).
    shard_name = f'{shard_prefix}_{shard_index+1}-of-{shard_count}.tfrecords'
    shard_path = path.join(staging_directory, shard_name)
    # Offset the trade ids so that they do not repeat across files.
    portfolio = generate_portfolio(
        shard_size, market_data, start_instrument_id=shard_index * shard_size)
    write_local(shard_path, portfolio)
    logging.info('Wrote portfolio shard to %s', shard_path)
    shard_paths.append(shard_path)

  if not is_gcs_path(destination):
    return

  # Upload stuff to GCS
  logging.info('Uploading data to gcs')
  upload_to_gcs(destination, [market_data_file_path])
  upload_to_gcs(destination, shard_paths)


# Script entry point: `app.run` parses the command line flags defined above
# and then calls `main`.
if __name__ == '__main__':
  app.run(main)
